{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import transformer\n",
    "from config import get_config, get_weights_file_path\n",
    "from train import get_model, get_dataset\n",
    "from validation import greed_decode\n",
    "\n",
    "import altair as alt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device: CUDA when available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f'use device: {device}')\n",
    "\n",
    "# get_dataset(config) is unpacked into two dataloaders and the source/target tokenizers;\n",
    "# the model is built from the two tokenizer vocabulary sizes and moved to `device`.\n",
    "config = get_config()\n",
    "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_dataset(config)\n",
    "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
    "\n",
    "# Resolve the checkpoint path for the hard-coded tag '29'\n",
    "# (presumably the epoch the weights were saved at; verify against the training run).\n",
    "model_filename = get_weights_file_path(config, \"29\")\n",
    "\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint written on a GPU\n",
    "# machine still loads when this notebook runs on a CPU-only machine.\n",
    "state = torch.load(model_filename, map_location=device)\n",
    "\n",
    "# NOTE(review): the module mode is not changed here; call model.eval() before inference\n",
    "# if later cells / helpers do not already do so.\n",
    "model.load_state_dict(state['model_state_dict'])\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
